{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLE and MAP Estimation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this short tutorial we review how to perform Maximum Likelihood Estimation (MLE) and Maximum a Posteriori (MAP) estimation in Pyro."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.distributions import constraints\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.infer import SVI, Trace_ELBO"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We consider the simple \"fair coin\" example covered in a [previous tutorial](http://pyro.ai/examples/svi_part_i.html#A-simple-example)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = torch.zeros(10)\n",
    "data[0:6] = 1.0\n",
    "\n",
    "def original_model(data):\n",
    "    f = pyro.sample(\"latent_fairness\", dist.Beta(10.0, 10.0))\n",
    "    with pyro.plate(\"data\", data.size(0)):\n",
    "        pyro.sample(\"obs\", dist.Bernoulli(f), obs=data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To facilitate comparison between different inference techniques, we construct a training helper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, guide, lr=0.01):\n",
    "    # start from an empty parameter store so that runs do not share state\n",
    "    pyro.clear_param_store()\n",
    "    adam = pyro.optim.Adam({\"lr\": lr})\n",
    "    svi = SVI(model, guide, adam, loss=Trace_ELBO())\n",
    "\n",
    "    n_steps = 101\n",
    "    for step in range(n_steps):\n",
    "        # `data` is the global tensor defined above\n",
    "        loss = svi.step(data)\n",
    "        if step % 50 == 0:\n",
    "            print('[iter {}]  loss: {:.4f}'.format(step, loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MLE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our model has a single latent variable, `latent_fairness`. To do Maximum Likelihood Estimation we simply \"demote\" it from a sampled latent variable to a Pyro parameter, which removes the prior from the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_mle(data):\n",
    "    # note that we need to include the interval constraint;\n",
    "    # in original_model() this constraint appears implicitly in\n",
    "    # the support of the Beta distribution.\n",
    "    f = pyro.param(\"latent_fairness\", torch.tensor(0.5),\n",
    "                   constraint=constraints.unit_interval)\n",
    "    with pyro.plate(\"data\", data.size(0)):\n",
    "        pyro.sample(\"obs\", dist.Bernoulli(f), obs=data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we no longer have any latent variables, our guide can be empty:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def guide_mle(data):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see what result we get."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[iter 0]  loss: 6.9315\n",
      "[iter 50]  loss: 6.7310\n",
      "[iter 100]  loss: 6.7301\n"
     ]
    }
   ],
   "source": [
    "train(model_mle, guide_mle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Our MLE estimate of the latent fairness is 0.601\n"
     ]
    }
   ],
   "source": [
    "print(\"Our MLE estimate of the latent fairness is {:.3f}\".format(\n",
    "      pyro.param(\"latent_fairness\").item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Thus with MLE we get a point estimate of `latent_fairness`."
   ]
  },
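  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the Bernoulli likelihood for this dataset can be maximized by hand. Since the guide is empty, the loss is simply the negative log-likelihood. Writing $f$ for `latent_fairness` and using the 6 ones and 4 zeros in `data`:\n",
    "\n",
    "$$\\log p(\\text{data} \\mid f) = 6 \\log f + 4 \\log(1 - f), \\qquad \\frac{6}{f} - \\frac{4}{1 - f} = 0 \\;\\Rightarrow\\; \\hat{f}_{\\text{MLE}} = \\frac{6}{10} = 0.6$$\n",
    "\n",
    "The negative log-likelihood at $f = 0.6$ is approximately 6.7301, which matches the final loss printed above. The reported estimate of 0.601 is slightly off from 0.6, presumably because the optimizer has not fully converged after 101 steps."
   ]
  },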
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MAP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With Maximum a Posteriori estimation, we also get a point estimate of our latent variables. The difference from MLE is that these estimates are regularized by the prior."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To do MAP in Pyro we use a [Delta distribution](http://docs.pyro.ai/en/stable/distributions.html#pyro.distributions.Delta) for the guide. Recall that the `Delta` distribution puts all its probability mass at a single value. Here that value is a learnable parameter, `f_map`, constrained to the unit interval."
   ]
  },
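  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why a `Delta` guide yields a MAP estimate, it helps to write out the loss. The following is a sketch, assuming the guide places all its mass at a single point $f_{\\text{map}}$ and that its log-density term contributes zero there (Pyro's `Delta` has a default `log_density` of 0):\n",
    "\n",
    "$$\\text{loss}(f_{\\text{map}}) = -\\text{ELBO} = -\\big[\\log p(\\text{data} \\mid f_{\\text{map}}) + \\log p(f_{\\text{map}})\\big]$$\n",
    "\n",
    "Minimizing this loss therefore maximizes the log joint density, and hence the posterior density $p(f \\mid \\text{data}) \\propto p(\\text{data} \\mid f)\\, p(f)$, over $f_{\\text{map}}$."
   ]
  },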
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def guide_map(data):\n",
    "    f_map = pyro.param(\"f_map\", torch.tensor(0.5),\n",
    "                       constraint=constraints.unit_interval)\n",
    "    pyro.sample(\"latent_fairness\", dist.Delta(f_map))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how this result differs from MLE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[iter 0]  loss: 5.6719\n",
      "[iter 50]  loss: 5.6006\n",
      "[iter 100]  loss: 5.6004\n"
     ]
    }
   ],
   "source": [
    "train(original_model, guide_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Our MAP estimate of the latent fairness is 0.536\n"
     ]
    }
   ],
   "source": [
    "print(\"Our MAP estimate of the latent fairness is {:.3f}\".format(\n",
    "      pyro.param(\"f_map\").item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To understand what's going on, note that the prior mean of `latent_fairness` in our model is 0.5, since that is the mean of `Beta(10.0, 10.0)`. The MLE (which ignores the prior) is determined entirely by the raw counts in the data (6 ones and 4 zeros, i.e. 6 heads and 4 tails if we read 1 as heads). In contrast, the MAP estimate is regularized towards the prior mean, which is why it lies between 0.5 and 0.6."
   ]
  }
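  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These numbers can be checked by hand, since the Beta prior is conjugate to the Bernoulli likelihood. With 6 ones and 4 zeros, the posterior over $f$ is $\\text{Beta}(10 + 6, 10 + 4) = \\text{Beta}(16, 14)$, and the MAP estimate is its mode:\n",
    "\n",
    "$$\\hat{f}_{\\text{MAP}} = \\frac{16 - 1}{16 + 14 - 2} = \\frac{15}{28} \\approx 0.536, \\qquad \\hat{f}_{\\text{MLE}} = \\frac{6}{10} = 0.6$$\n",
    "\n",
    "The closed-form mode agrees with the value of 0.536 found by SVI above."
   ]
  }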
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
